{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generate synthetic text classification data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- **Goal**: Generate synthetic text classification data to augment an imbalanced and limited dataset for training a topic classifier. In addition, generate new data for training a fact-based versus opinion-based classifier to add a new label.\n",
    "- **Libraries**: [argilla](https://github.com/argilla-io/argilla), [hf-inference-endpoints](https://github.com/huggingface/huggingface_hub), [SetFit](https://github.com/huggingface/setfit)\n",
    "- **Components**: [LoadDataFromDicts](https://distilabel.argilla.io/latest/components-gallery/steps/loaddatafromdicts/), [EmbeddingTaskGenerator](https://distilabel.argilla.io/latest/components-gallery/tasks/embeddingtaskgenerator/), [GenerateTextClassificationData](https://distilabel.argilla.io/latest/components-gallery/tasks/generatetextclassificationdata/)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Getting started\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Install the dependencies\n",
    "\n",
    "To complete this tutorial, you need to install the distilabel SDK and a few third-party libraries via pip. We will be using **the free but rate-limited Hugging Face serverless Inference API** for this tutorial, so we need to install this as an extra distilabel dependency. You can install them by running the following command:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install \"distilabel[hf-inference-endpoints]\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install \"transformers~=4.40\" \"torch~=2.0\" \"setfit~=1.0\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's make the required imports:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from collections import Counter\n",
    "\n",
    "from datasets import load_dataset, Dataset\n",
    "from distilabel.models import InferenceEndpointsLLM\n",
    "from distilabel.pipeline import Pipeline\n",
    "from distilabel.steps import LoadDataFromDicts\n",
    "from distilabel.steps.tasks import (\n",
    "    GenerateTextClassificationData,\n",
    ")\n",
    "from setfit import SetFitModel, Trainer, sample_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You'll need an `HF_TOKEN` to use the HF Inference Endpoints. Log in to use it directly within this notebook.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from huggingface_hub import login\n",
    "\n",
    "login(token=os.getenv(\"HF_TOKEN\"), add_to_git_credential=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### (optional) Deploy Argilla\n",
    "\n",
    "You can skip this step or replace it with any other data evaluation tool, but the quality of your model will suffer from a lack of data quality, so we do recommend looking at your data. If you already deployed Argilla, you can skip this step. Otherwise, you can quickly deploy Argilla following [this guide](https://docs.argilla.io/latest/getting_started/quickstart/).\n",
    "\n",
    "Along with that, you will need to install Argilla as a distilabel extra.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install \"distilabel[argilla, hf-inference-endpoints]\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The dataset\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will use the [`fancyzhx/ag_news`](https://huggingface.co/datasets/fancyzhx/ag_news) dataset from the Hugging Face Hub as our original data source. To simulate a real-world scenario with imbalanced and limited data, we will load only 20 samples from this dataset.\n",
    "\n",
    "<iframe\n",
    "  src=\"https://huggingface.co/datasets/fancyzhx/ag_news/embed/viewer/default/train\"\n",
    "  frameborder=\"0\"\n",
    "  width=\"100%\"\n",
    "  height=\"560px\"\n",
    "></iframe>\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hf_dataset = load_dataset(\"fancyzhx/ag_news\", split=\"train[-20:]\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we can retrieve the available labels in the dataset and examine the current data distribution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{0: 'World', 1: 'Sports', 2: 'Business', 3: 'Sci/Tech'}\n",
      "Counter({0: 12, 1: 6, 2: 2})\n"
     ]
    }
   ],
   "source": [
    "labels_topic = hf_dataset.features[\"label\"].names\n",
    "id2str = {i: labels_topic[i] for i in range(len(labels_topic))}\n",
    "print(id2str)\n",
    "print(Counter(hf_dataset[\"label\"]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As observed, the dataset is imbalanced, with most samples falling under the `World` category, while the `Sci/Tech` category is entirely missing. Moreover, there are insufficient samples to effectively train a topic classification model.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will also define the labels for the new classification task."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels_fact_opinion = [\"Fact-based\", \"Opinion-based\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the text classification task\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To generate the data we will use the `GenerateTextClassificationData` task. This task will use as input classification tasks and we can define the language, difficulty and clarity required for the generated data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'role': 'user', 'content': 'You have been assigned a text classification task: Classify the news article as fact-based or opinion-based\\n\\nYour mission is to write one text classification example for this task in JSON format. The JSON object must contain the following keys:\\n - \"input_text\": a string, the input text specified by the classification task.\\n - \"label\": a string, the correct label of the input text.\\n - \"misleading_label\": a string, an incorrect label that is related to the task.\\n\\nPlease adhere to the following guidelines:\\n - The \"input_text\" should be diverse in expression.\\n - The \"misleading_label\" must be a valid label for the given task, but not as appropriate as the \"label\" for the \"input_text\".\\n - The values for all fields should be in English.\\n - Avoid including the values of the \"label\" and \"misleading_label\" fields in the \"input_text\", that would make the task too easy.\\n - The \"input_text\" is clear and requires college level education to comprehend.\\n\\nYour output must always be a JSON object only, do not explain yourself or output anything else. Be creative!'}]\n"
     ]
    }
   ],
   "source": [
    "task = GenerateTextClassificationData(\n",
    "    language=\"English\",\n",
    "    difficulty=\"college\",\n",
    "    clarity=\"clear\",\n",
    "    num_generations=1,\n",
    "    llm=InferenceEndpointsLLM(\n",
    "        model_id=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
    "        tokenizer_id=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
    "        generation_kwargs={\"max_new_tokens\": 512, \"temperature\": 0.4},\n",
    "    ),\n",
    "    input_batch_size=5,\n",
    ")\n",
    "task.load()\n",
    "result = next(\n",
    "    task.process([{\"task\": \"Classify the news article as fact-based or opinion-based\"}])\n",
    ")\n",
    "print(result[0][\"distilabel_metadata\"][\"raw_input_generate_text_classification_data_0\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For our use case, we only need to generate data for two tasks: a topic classification task and a fact versus opinion classification task. Therefore, we will define the tasks accordingly. As we will be using an smaller model for generation, we will select 2 random labels for each topic classification task and change the order for the fact versus opinion classification task ensuring more diversity in the generated data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "task_templates = [\n",
    "    \"Determine the news article as {}\",\n",
    "    \"Classify news article as {}\",\n",
    "    \"Identify the news article as {}\",\n",
    "    \"Categorize the news article as {}\",\n",
    "    \"Label the news article using {}\",\n",
    "    \"Annotate the news article based on {}\",\n",
    "    \"Determine the theme of a news article from {}\",\n",
    "    \"Recognize the topic of the news article as {}\",\n",
    "]\n",
    "\n",
    "classification_tasks = [\n",
    "    {\"task\": action.format(\" or \".join(random.sample(labels_topic, 2)))}\n",
    "    for action in task_templates for _ in range(4)\n",
    "] + [\n",
    "    {\"task\": action.format(\" or \".join(random.sample(labels_fact_opinion, 2)))}\n",
    "    for action in task_templates\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the pipeline\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, it's time to define and run the pipeline. As mentioned, we will load the written tasks and feed them into the `GenerateTextClassificationData` task. For our use case, we will be using `Meta-Llama-3.1-8B-Instruct` via the `InferenceEndpointsLLM`, with different degrees of difficulty and clarity.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "difficulties = [\"college\", \"high school\", \"PhD\"]\n",
    "clarity = [\"clear\", \"understandable with some effort\", \"ambiguous\"]\n",
    "\n",
    "with Pipeline(\"texcat-generation-pipeline\") as pipeline:\n",
    "\n",
    "    tasks_generator = LoadDataFromDicts(data=classification_tasks)\n",
    "\n",
    "    generate_data = []\n",
    "    for difficulty in difficulties:\n",
    "        for clarity_level in clarity:\n",
    "            task = GenerateTextClassificationData(\n",
    "                language=\"English\",\n",
    "                difficulty=difficulty,\n",
    "                clarity=clarity_level,\n",
    "                num_generations=2,\n",
    "                llm=InferenceEndpointsLLM(\n",
    "                    model_id=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
    "                    tokenizer_id=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
    "                    generation_kwargs={\"max_new_tokens\": 512, \"temperature\": 0.7},\n",
    "                ),\n",
    "                input_batch_size=5,\n",
    "            )\n",
    "            generate_data.append(task)\n",
    "\n",
    "    for task in generate_data:\n",
    "        tasks_generator.connect(task)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's now run the pipeline and generate the synthetic data.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "distiset = pipeline.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'task': 'Determine the news article as Business or World',\n",
       " 'input_text': \"The recent decision by the European Central Bank to raise interest rates will likely have a significant impact on the eurozone's economic growth, with some analysts predicting a 0.5% contraction in GDP due to the increased borrowing costs. The move is seen as a measure to combat inflation, which has been rising steadily over the past year.\",\n",
       " 'label': 'Business',\n",
       " 'misleading_label': 'World',\n",
       " 'distilabel_metadata': {'raw_output_generate_text_classification_data_0': '{\\n  \"input_text\": \"The recent decision by the European Central Bank to raise interest rates will likely have a significant impact on the eurozone\\'s economic growth, with some analysts predicting a 0.5% contraction in GDP due to the increased borrowing costs. The move is seen as a measure to combat inflation, which has been rising steadily over the past year.\",\\n  \"label\": \"Business\",\\n  \"misleading_label\": \"World\"\\n}'},\n",
       " 'model_name': 'meta-llama/Meta-Llama-3.1-8B-Instruct'}"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "distiset[\"generate_text_classification_data_0\"][\"train\"][0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can push the dataset to the Hub for sharing with the community and [embed it to explore the data](https://huggingface.co/docs/hub/datasets-viewer-embed).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "distiset.push_to_hub(\"[your-owner-name]/example-texcat-generation-dataset\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<iframe\n",
    "  src=\"https://huggingface.co/datasets/distilabel-internal-testing/example-texcat-generation-dataset/embed/viewer/generate_text_classification_data_1/train\"\n",
    "  frameborder=\"0\"\n",
    "  width=\"100%\"\n",
    "  height=\"560px\"\n",
    "></iframe>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By examining the distiset distribution, we can confirm that it includes at least the 8 required samples for each label to train our classification models with SetFit."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Counter({'Sci/Tech': 275,\n",
       "         'Business': 130,\n",
       "         'World': 86,\n",
       "         'Fact-based': 86,\n",
       "         'Sports': 64,\n",
       "         'Opinion-based': 54,\n",
       "         None: 20,\n",
       "         'Opinion Based': 1,\n",
       "         'News/Opinion': 1,\n",
       "         'Science': 1,\n",
       "         'Environment': 1,\n",
       "         'Opinion': 1})"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "all_labels = [\n",
    "    entry[\"label\"]\n",
    "    for dataset_name in distiset\n",
    "    for entry in distiset[dataset_name][\"train\"]\n",
    "]\n",
    "\n",
    "Counter(all_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will create two datasets with the required labels and data for our use cases."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_rows(distiset, labels):\n",
    "    return [\n",
    "        {\n",
    "            \"text\": entry[\"input_text\"],\n",
    "            \"label\": entry[\"label\"],\n",
    "            \"id\": i\n",
    "        }\n",
    "        for dataset_name in distiset\n",
    "        for i, entry in enumerate(distiset[dataset_name][\"train\"])\n",
    "        if entry[\"label\"] in labels\n",
    "    ]\n",
    "\n",
    "data_topic = extract_rows(distiset, labels_topic)\n",
    "data_fact_opinion = extract_rows(distiset, labels_fact_opinion)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## (Optional) Evaluate with Argilla\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "!!! note \"Get started in Argilla\"\n",
    "    If you are not familiar with Argilla, we recommend taking a look at the [Argilla quickstart docs](https://docs.argilla.io/latest/getting_started/quickstart/). Alternatively, you can use your Hugging Face account to login to the [Argilla demo Space](https://argilla-argilla-template-space.hf.space).\n",
    "\n",
    "To get the most out of our data, we will use Argilla. First, we need to connect to the Argilla instance.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import argilla as rg\n",
    "\n",
    "# Replace api_url with your url if using Docker\n",
    "# Replace api_key with your API key under \"My Settings\" in the UI\n",
    "# Uncomment the last line and set your HF_TOKEN if your space is private\n",
    "client = rg.Argilla(\n",
    "    api_url=\"https://[your-owner-name]-[your_space_name].hf.space\",\n",
    "    api_key=\"[your-api-key]\",\n",
    "    # headers={\"Authorization\": f\"Bearer {HF_TOKEN}\"}\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will create a `Dataset` for each task, with an input `TextField` for the text classification text and a `LabelQuestion` to ensure the generated labels are correct.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_texcat_dataset(dataset_name, labels):\n",
    "    settings = rg.Settings(\n",
    "        fields=[rg.TextField(\"text\")],\n",
    "        questions=[\n",
    "            rg.LabelQuestion(\n",
    "                name=\"label\",\n",
    "                title=\"Classify the texts according to the following labels\",\n",
    "                labels=labels,\n",
    "            ),\n",
    "        ],\n",
    "    )\n",
    "    return rg.Dataset(name=dataset_name, settings=settings).create()\n",
    "\n",
    "\n",
    "rg_dataset_topic = create_texcat_dataset(\"topic-classification\", labels_topic)\n",
    "rg_dataset_fact_opinion = create_texcat_dataset(\n",
    "    \"fact-opinion-classification\", labels_fact_opinion\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we can upload the generated data to Argilla and evaluate it. We will use the generated labels as suggestions.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rg_dataset_topic.records.log(data_topic)\n",
    "rg_dataset_fact_opinion.records.log(data_fact_opinion)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we can start the annotation process. Just open the dataset in the Argilla UI and start annotating the records. If the suggestions are correct, you can just click on `Submit`. Otherwise, you can select the correct label.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "!!! note\n",
    "    Check this [how-to guide](https://docs.argilla.io/latest/how_to_guides/annotate/) to know more about annotating in the UI.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once, you get the annotations, let's continue by retrieving the data from Argilla and format it as a dataset with the required data.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rg_dataset_topic = client.datasets(\"topic-classification\")\n",
    "rg_dataset_fact_opinion = client.datasets(\"fact-opinion-classification\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "status_filter = rg.Query(filter=rg.Filter((\"response.status\", \"==\", \"submitted\")))\n",
    "\n",
    "submitted_topic = rg_dataset_topic.records(status_filter).to_list(flatten=True)\n",
    "submitted_fact_opinion = rg_dataset_fact_opinion.records(status_filter).to_list(\n",
    "    flatten=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_submitted(submitted):\n",
    "    return [\n",
    "        {\n",
    "            \"text\": r[\"text\"],\n",
    "            \"label\": r[\"label.responses\"][0],\n",
    "            \"id\": i,\n",
    "        }\n",
    "        for i, r in enumerate(submitted)\n",
    "    ]\n",
    "\n",
    "data_topic = format_submitted(submitted_topic)\n",
    "data_fact_opinion = format_submitted(submitted_fact_opinion)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train your models\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In our case, we will fine-tune using SetFit. However, you can select the one that best fits your requirements.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Formatting the data\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The next step will be to format the data to be compatible with SetFit. In the case of the topic classification, we will need to combine the synthetic data with the original data.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hf_topic = hf_dataset.to_list()\n",
    "num = len(data_topic)\n",
    "\n",
    "data_topic.extend(\n",
    "    [\n",
    "        {\n",
    "            \"text\": r[\"text\"],\n",
    "            \"label\": id2str[r[\"label\"]],\n",
    "            \"id\": num + i,\n",
    "        }\n",
    "        for i, r in enumerate(hf_topic)\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we check the data distribution now, we can see that we have enough samples for each label to train our models.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Counter({'Sci/Tech': 275, 'Business': 132, 'World': 98, 'Sports': 70})"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "labels = [record[\"label\"] for record in data_topic]\n",
    "Counter(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Counter({'Fact-based': 86, 'Opinion-based': 54})"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "labels = [record[\"label\"] for record in data_fact_opinion]\n",
    "Counter(labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's create our training and validation datasets. The training dataset will gather 8 samples by label. In this case, the validation datasets will contain the remaining samples not included in the training datasets.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_and_split(dataset, label_column, num_samples):\n",
    "    train_dataset = sample_dataset(\n",
    "        dataset, label_column=label_column, num_samples=num_samples\n",
    "    )\n",
    "    eval_dataset = dataset.filter(lambda x: x[\"id\"] not in set(train_dataset[\"id\"]))\n",
    "    return train_dataset, eval_dataset\n",
    "\n",
    "\n",
    "dataset_topic_full = Dataset.from_list(data_topic)\n",
    "dataset_fact_opinion_full = Dataset.from_list(data_fact_opinion)\n",
    "\n",
    "train_dataset_topic, eval_dataset_topic = sample_and_split(\n",
    "    dataset_topic_full, \"label\", 8\n",
    ")\n",
    "train_dataset_fact_opinion, eval_dataset_fact_opinion = sample_and_split(\n",
    "    dataset_fact_opinion_full, \"label\", 8\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The actual training\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's train our models for each task! We will use [TaylorAI/bge-micro-v2](https://huggingface.co/TaylorAI/bge-micro-v2), available in the Hugging Face Hub. You can check the [MTEB leaderboard](https://huggingface.co/spaces/mteb/leaderboard) to select the best model for your use case."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 126,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(model_name, dataset, eval_dataset):\n",
    "    model = SetFitModel.from_pretrained(model_name)\n",
    "\n",
    "    trainer = Trainer(\n",
    "        model=model,\n",
    "        train_dataset=dataset,\n",
    "    )\n",
    "    trainer.train()\n",
    "    metrics = trainer.evaluate(eval_dataset)\n",
    "    print(metrics)\n",
    "\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "***** Running training *****\n",
      "  Num unique pairs = 768\n",
      "  Batch size = 16\n",
      "  Num epochs = 1\n",
      "  Total optimization steps = 48\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'embedding_loss': 0.1873, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.02}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "***** Running evaluation *****\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'train_runtime': 4.9767, 'train_samples_per_second': 154.318, 'train_steps_per_second': 9.645, 'epoch': 1.0}\n",
      "{'accuracy': 0.8333333333333334}\n"
     ]
    }
   ],
   "source": [
    "model_topic = train_model(\n",
    "    model_name=\"TaylorAI/bge-micro-v2\",\n",
    "    dataset=train_dataset_topic,\n",
    "    eval_dataset=eval_dataset_topic,\n",
    ")\n",
    "model_topic.save_pretrained(\"topic_classification_model\")\n",
    "model_topic = SetFitModel.from_pretrained(\"topic_classification_model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 128,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "***** Running training *****\n",
      "  Num unique pairs = 144\n",
      "  Batch size = 16\n",
      "  Num epochs = 1\n",
      "  Total optimization steps = 9\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'embedding_loss': 0.2985, 'learning_rate': 2e-05, 'epoch': 0.11}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "***** Running evaluation *****\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'train_runtime': 0.8327, 'train_samples_per_second': 172.931, 'train_steps_per_second': 10.808, 'epoch': 1.0}\n",
      "{'accuracy': 0.9090909090909091}\n"
     ]
    }
   ],
   "source": [
    "model_fact_opinion = train_model(\n",
    "    model_name=\"TaylorAI/bge-micro-v2\",\n",
    "    dataset=train_dataset_fact_opinion,\n",
    "    eval_dataset=eval_dataset_fact_opinion,\n",
    ")\n",
    "model_fact_opinion.save_pretrained(\"fact_opinion_classification_model\")\n",
    "model_fact_opinion = SetFitModel.from_pretrained(\"fact_opinion_classification_model\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Voilà! The models are now trained and ready to be used. You can start making predictions to check the model's performance and add the new label. Optionally, you can continue using distilabel to generate additional data or Argilla to verify the quality of the predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 129,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(model, input, labels):\n",
    "    model.labels = labels\n",
    "    prediction = model.predict([input])\n",
    "    return prediction[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 130,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Sci/Tech'"
      ]
     },
     "execution_count": 130,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predict(\n",
    "    model_topic, \"The new iPhone is expected to be released next month.\", labels_topic\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 131,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Opinion-based'"
      ]
     },
     "execution_count": 131,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predict(\n",
    "    model_fact_opinion,\n",
    "    \"The new iPhone is expected to be released next month.\",\n",
    "    labels_fact_opinion,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusions\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial, we showcased the detailed steps to build a pipeline for generating text classification data using distilabel. You can customize this pipeline for your own use cases and share your datasets with the community through the Hugging Face Hub.\n",
    "\n",
    "We defined two text classification tasks—a topic classification task and a fact versus opinion classification task—and generated new data using various models via the serverless Hugging Face Inference API. Then, we curated the generated data with Argilla. Finally, we trained the models with SetFit using both the original and synthetic data."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "distilabel-tutorials",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
